Model post process for zero stage3 training #17187
Conversation
orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
LGTM
Force-pushed from 86704cf to 4e59594
```python
    tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
    # output = input.matmul(weight.t())
    tensor_input_shapes[0]  # input
```
CodeQL (code scanning) notice: statement has no effect.
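For context, the reviewed snippet appears to belong to a shape-inference routine for a linear layer, where `output = input.matmul(weight.t())`. A minimal sketch of that shape rule follows; the function name and parameters here are illustrative assumptions, not the PR's actual implementation:

```python
from typing import List, Union

def infer_linear_output_shape(
    input_shape: List[Union[int, str]],
    weight_shape: List[Union[int, str]],
) -> List[Union[int, str]]:
    # For output = input.matmul(weight.t()):
    # input is [..., in_features], weight is [out_features, in_features],
    # so the output keeps the leading dims and replaces the last
    # dimension with out_features. Symbolic dims (strings) pass through.
    return input_shape[:-1] + [weight_shape[0]]

# Example: a [8, 16] input projected by a [32, 16] weight.
print(infer_linear_output_shape([8, 16], [32, 16]))  # [8, 32]
```

Note that a bare expression like `tensor_input_shapes[0]` (as flagged by CodeQL) reads a value without using it; a routine like the above would need to bind it to a name before combining it with the weight shape.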
orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
LGTM
Thank you so much, @askhade!
### Model post process for zero stage3 training

This is the last change needed to make both single-GPU and multi-GPU runs pass.

Design details: https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

`PyTorch` runs with `ZeROOffloadSubscriber`:

```python
model = prepare_model(...)

from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
configure_ort_compatible_zero_stage3()
```

`ORTModule` runs with `ZeROOffloadSubscriber`:

```python
os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1'

from onnxruntime.training.ortmodule import ORTModule
model = ORTModule(self.model)
```

Debugging convergence issues becomes much easier when both ORT and PyTorch can run the same offload path.

### Motivation and Context